# Lint as: python3
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=line-too-long
"""
train_hello_world_model.ipynb

Original script automatically generated by Colaboratory
 (File -> Download -> xxx.py).

Original file is located at
    https://colab.research.google.com/github/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/hello_world/train/train_hello_world_model.ipynb

# Train a Simple TensorFlow Lite for Microcontrollers model

This notebook demonstrates the process of training a 2.5 kB model using
 TensorFlow and converting it for use with TensorFlow Lite for
 Microcontrollers.

Deep learning networks learn to model patterns in underlying data.
 Here, we're going to train a network to model data generated by a
 [sine](https://en.wikipedia.org/wiki/Sine) function.
 This will result in a model that can take a value, `x`,
 and predict its sine, `y`.

The model created in this notebook is used in the
 [hello_world](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/hello_world)
 example for [TensorFlow Lite for MicroControllers](https://www.tensorflow.org/lite/microcontrollers/overview).

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/hello_world/train/train_hello_world_model.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/hello_world/train/train_hello_world_model.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

Tested with Python 3.7.5 on Ubuntu 18.04.6 in a VirtualBox VM
"""
# pylint: enable=line-too-long

import os
import sys
import subprocess
import math
import textwrap
import traceback
from typing import NamedTuple, Optional

## Shell command utility method


def run_command(cmd: str) -> None:
  """
  Run a shell command and print its combined stdout/stderr.

  The captured output is printed whether or not the command succeeds, so
  the diagnostics of a failing command are not lost.

  Args:
    cmd: str containing the command line to run through the shell

  Throws: CalledProcessError if the command exits with a non-zero status
  """
  # check=False: the exit status is verified manually below, after the
  # captured output has been printed. With check=True the exception would be
  # raised inside subprocess.run() and the output would never be shown.
  result = subprocess.run(cmd,
                          shell=True,
                          check=False,
                          text=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT)
  print(result.stdout)
  # Raises CalledProcessError (carrying the captured output) on failure
  result.check_returncode()


## Setup Environment

if __name__ == '__main__':
  # Install Dependencies
  # pip is invoked as a module of the interpreter that is running this script
  # (sys.executable), so the packages are installed into the environment the
  # imports below are resolved from. A bare `pip` found on PATH may belong to
  # a different Python installation.
  pip_install = f'"{sys.executable}" -m pip install'

  run_command(f'{pip_install} --upgrade pip')
  run_command(f'{pip_install} tensorflow==2.7.0')  # tensorflow installs numpy
  run_command(f'{pip_install} pandas')
  run_command(f'{pip_install} matplotlib')

  # Import Dependencies
  # These imports come after the installation above on purpose: the packages
  # may not be present before it has run.
  # TensorFlow is an open source machine learning library
  import tensorflow as tf
  # Keras is TensorFlow's high-level API for deep learning
  from tensorflow import keras
  # Numpy is a math library
  import numpy as np
  # Pandas is a data manipulation library
  import pandas as pd
  # Matplotlib is a graphing library
  import matplotlib.pyplot as plt

## Configure Defaults


class ConfigData(NamedTuple):
  """
  Configuration data for the training script.

  Every setting is a NamedTuple field with a default value: `ConfigData()`
  yields the stock configuration, and individual settings can be overridden
  by keyword, e.g. `ConfigData(SAMPLES=200)`. Settings are read through an
  instance (`config.SAMPLES`).
  """

  # Define RNG seed
  SEED_VALUE: int = 1

  # Define paths to model files
  MODEL_TF: str = 'model_tf'
  MODEL_NO_QUANT_TFLITE: str = 'model_no_quant.tflite'
  MODEL_QUANT_TFLITE: str = 'model_quant.tflite'

  # Number of sample datapoints
  SAMPLES: int = 1000
  # Number of representative samples for quantized model conversion
  REPRESENTATIVE_SAMPLES: int = 500

  # Define paths to plot files (no suffix; passed to pyplot.savefig as-is)
  PLOT_GEN_DATA: str = 'generated_data'
  PLOT_GEN_DATA_NOISY: str = 'generated_data_noisy'
  PLOT_GEN_DATA_SPLIT: str = 'generated_data_split'
  PLOT_TRAIN_PREDICTION: str = 'train_prediction'
  PLOT_TRAIN_VALIDATION: str = 'train_validation'
  PLOT_COMPARE_PREDICTION: str = 'compare_prediction'

  # Define width of plot titles (characters per line before wrapping)
  PLOT_TITLE_WIDTH: int = 50

  # Loss and MAE limits: upper bounds that the script's assertions compare
  # the measured metrics against
  TRAIN_LOSS: float = 0.0121
  TRAIN_MAE: float = 0.0885
  VAL_LOSS: float = 0.0111
  VAL_MAE: float = 0.0856
  TEST_LOSS: float = 0.0108
  TEST_MAE: float = 0.0832
  TEST_QUANT_LOSS: float = 0.0118


class Dataset(NamedTuple):
  """Training, validation and test partitions of the x and y values."""
  # x values of the training, validation and test partitions
  x_train: np.ndarray
  x_validate: np.ndarray
  x_test: np.ndarray
  # y values of the same partitions, in the same order
  y_train: np.ndarray
  y_validate: np.ndarray
  y_test: np.ndarray


## Script utility methods


# Set seed for experiment reproducibility
def set_rng_seed(config: ConfigData) -> None:
  """Seed the NumPy and TensorFlow random generators with config.SEED_VALUE"""
  seed = config.SEED_VALUE
  np.random.seed(seed)
  tf.random.set_seed(seed)


def save_plot(config: ConfigData, title: Optional[str], filename: str) -> None:
  """
  Write the current matplotlib.pyplot figure to an image file

  Args:
    config: script configuration data (supplies the title wrap width)
    title: figure title, or None to leave the title untouched
    filename: name of file to store image to (no suffix required)

  Throws: OSError, ValueError
  """
  if title is not None:
    # Wrap at word boundaries only: long words and hyphenated words are
    # never split across lines.
    wrapped_title = textwrap.fill(title,
                                  width=config.PLOT_TITLE_WIDTH,
                                  break_long_words=False,
                                  break_on_hyphens=False)
    plt.title(wrapped_title)
  plt.savefig(filename)


## Dataset


def generate_data(config: ConfigData) -> Dataset:
  """
  ### 1. Generate Data
  Generate a set of random `x` values, calculate their sine values, add
  noise, and split the result into training, test and validation
  partitions. A plot of each of the three stages is saved to its own file.

  Args:
    config: script configuration data

  Returns: Dataset holding the train/validate/test partitions

  Throws: AssertionError if the partition sizes do not add up to
    config.SAMPLES
  """
  # Generate a uniformly distributed set of random numbers in the range from
  # 0 to 2π, which covers a complete sine wave oscillation
  x_values = np.random.uniform(low=0, high=2 * math.pi,
                               size=config.SAMPLES).astype(np.float32)

  # Shuffle the values to guarantee they're not in order
  np.random.shuffle(x_values)

  # Calculate the corresponding sine values
  y_values = np.sin(x_values).astype(np.float32)

  # Plot our data. The 'b.' argument tells the library to print blue dots.
  # The figure is cleared before each of the three plots in this function so
  # that every saved image shows only its own data.
  plt.clf()
  plt.plot(x_values, y_values, 'b.')
  title = 'Simple sine wave'
  save_plot(config, title=title, filename=config.PLOT_GEN_DATA)

  ### 2. Add Noise
  # Since it was generated directly by the sine function, our data fits a nice,
  # smooth curve.

  # However, machine learning models are good at extracting underlying meaning
  # from messy, real world data. To demonstrate this, we can add some noise to
  # our data to approximate something more life-like.

  # Add a small random number to each y value
  y_values += 0.1 * np.random.randn(*y_values.shape)

  # Plot our data on a fresh figure (without the clean sine wave underneath)
  plt.clf()
  plt.plot(x_values, y_values, 'b.')
  save_plot(config, title='Add Noise', filename=config.PLOT_GEN_DATA_NOISY)

  ### 3. Split the Data
  # We now have a noisy dataset that approximates real world data. We'll be
  # using this to train our model.

  # To evaluate the accuracy of the model we train, we'll need to compare its
  # predictions to real data and check how well they match up. This evaluation
  # happens during training (where it is referred to as validation) and after
  # training (referred to as testing). It's important in both cases that we
  # use fresh data that was not already used to train the model.

  # The data is split as follows:
  #   1. Training: 60%
  #   2. Validation: 20%
  #   3. Testing: 20%

  # We'll use 60% of our data for training and 20% for testing. The remaining
  # 20% will be used for validation. Calculate the indices of each section.
  train_split = int(0.6 * config.SAMPLES)
  test_split = int(0.2 * config.SAMPLES + train_split)

  # Use np.split to chop our data into three parts.
  # The second argument to np.split is an array of indices where the data will
  # be split. We provide two indices, so the data will be divided into three
  # chunks: training, then test, then validation.
  x_train, x_test, x_validate = np.split(x_values, [train_split, test_split])
  y_train, y_test, y_validate = np.split(y_values, [train_split, test_split])

  # Double check that our splits add up correctly
  assert (x_train.size + x_validate.size +
          x_test.size) == config.SAMPLES, f'x dataset size != {config.SAMPLES}'
  assert (y_train.size + y_validate.size +
          y_test.size) == config.SAMPLES, f'y dataset size != {config.SAMPLES}'

  # Plot the data in each partition in different colors, again on a fresh
  # figure so the earlier plots do not show through:
  plt.clf()
  plt.plot(x_train, y_train, 'b.', label='Train')
  plt.plot(x_test, y_test, 'r.', label='Test')
  plt.plot(x_validate, y_validate, 'y.', label='Validate')
  plt.legend()
  save_plot(config, title='Split Data', filename=config.PLOT_GEN_DATA_SPLIT)

  return Dataset(x_train, x_validate, x_test, y_train, y_validate, y_test)


def create_model() -> tf.keras.Model:
  """
  ## Training

  ### 1. Design the Model
  Build and compile a small fully connected network: two hidden layers of
  16 'relu' neurons each, followed by a single output neuron.

  Returns: the compiled (untrained) Keras model
  """
  layers = [
      # First layer takes a scalar input and feeds it through 16 "neurons".
      # The neurons decide whether to activate based on the 'relu'
      # activation function.
      keras.layers.Dense(16, activation='relu', input_shape=(1, )),
      # A second hidden layer helps the network learn more complex
      # representations
      keras.layers.Dense(16, activation='relu'),
      # Final layer is a single neuron, since we want to output a single
      # value
      keras.layers.Dense(1),
  ]
  network = tf.keras.Sequential(layers)

  # Standard 'adam' optimizer with the mean squared error ('mse') loss for
  # regression; mean absolute error ('mae') is tracked as an extra metric.
  network.compile(optimizer='adam', loss='mse', metrics=['mae'])

  return network


def train_model(config: ConfigData, model: tf.keras.Model,
                ds: Dataset) -> None:
  """
  ### 2. Train the Model ###
  Fit the model on the training partition, save it to config.MODEL_TF, plot
  the training history and the test predictions, and check the final
  metrics against the limits in the configuration.

  Args:
    config: script configuration data
    model: compiled Keras model to train
    ds: dataset partitions used for training, validation and testing

  Throws: ImportError, AssertionError
  """

  # Train the model
  fit_result = model.fit(ds.x_train,
                         ds.y_train,
                         epochs=500,
                         batch_size=64,
                         validation_data=(ds.x_validate, ds.y_validate),
                         verbose=2)

  # Save the model to disk
  model.save(config.MODEL_TF)

  # ### 3. Plot Metrics
  # Each training epoch, the model prints out its loss and mean absolute error
  # for training and validation, for example (exact numbers may differ):

  # Epoch 500/500
  # 10/10 - 0s - loss: 0.0121 - mae: 0.0884 -
  #  val_loss: 0.0111 - val_mae: 0.0856 - 21ms/epoch - 2ms/step

  # The loss is the distance between the predicted and actual values during
  # training and validation.
  metrics = fit_result.history
  train_loss = metrics['loss']
  val_loss = metrics['val_loss']

  epochs = range(1, len(train_loss) + 1)

  # Exclude the first few epochs so the graph is easier to read
  skip = 100
  shown_epochs = epochs[skip:]

  plt.figure(figsize=(10, 4))

  # One panel per metric, side by side: loss on the left, mean absolute
  # error (another way of measuring the prediction error) on the right.
  panels = (
      ('loss', 'val_loss', 'Training loss', 'Validation loss',
       'Training and validation loss', 'Loss'),
      ('mae', 'val_mae', 'Training MAE', 'Validation MAE',
       'Training and validation mean absolute error', 'MAE'),
  )
  for position, panel in enumerate(panels, start=1):
    train_key, val_key, train_label, val_label, panel_title, y_label = panel
    plt.subplot(1, 2, position)
    plt.plot(shown_epochs, metrics[train_key][skip:], 'g.', label=train_label)
    plt.plot(shown_epochs, metrics[val_key][skip:], 'b.', label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
  save_plot(config, title=None, filename=config.PLOT_TRAIN_VALIDATION)

  # The metrics of the last epoch must not exceed the configured limits
  final_checks = (
      ('loss', config.TRAIN_LOSS, 'Training loss too large'),
      ('val_loss', config.VAL_LOSS, 'Validation loss too large'),
      ('mae', config.TRAIN_MAE, 'Training MAE too large'),
      ('val_mae', config.VAL_MAE, 'Validation MAE too large'),
  )
  for key, limit, message in final_checks:
    assert metrics[key][-1] <= limit, message

  # Validation metrics that are better than the training metrics indicate
  # that the network is not overfitting. Validation can come out better
  # because validation metrics are calculated at the end of each epoch, while
  # training metrics are calculated throughout the epoch, so validation
  # happens on a model that has been trained slightly longer.

  # To confirm that the network performs well, check its predictions against
  # the test dataset we set aside earlier:

  # Calculate and print the loss on our test dataset
  print('Evaluation loss:')
  test_loss, test_mae = model.evaluate(ds.x_test, ds.y_test)

  # Make predictions based on our test dataset
  predicted = model.predict(ds.x_test)

  # Graph the predictions against the actual values (the figure created above
  # is cleared and reused)
  plt.clf()
  plt.title('Comparison of predictions and actual values')
  plt.plot(ds.x_test, ds.y_test, 'b.', label='Actual values')
  plt.plot(ds.x_test, predicted, 'r.', label='TF predicted')
  plt.legend()
  save_plot(config, title=None, filename=config.PLOT_TRAIN_PREDICTION)

  assert test_loss <= config.TEST_LOSS, 'Test loss too large'
  assert test_mae <= config.TEST_MAE, 'Test MAE too large'

  # The evaluation metrics we printed show that the model has a low loss and
  # MAE on the test data, and the predictions line up visually with our data
  # fairly well.

  # The model isn't perfect; its predictions don't form a smooth sine curve.
  # For instance, the line is almost straight when `x` is between 4.2 and 5.2.
  # If we wanted to go further, we could try further increasing the capacity
  # of the model, perhaps using some techniques to defend from overfitting.

  # However, an important part of machine learning is *knowing when to stop*.
  # This model is good enough for our use case - which is to make some LEDs
  # blink in a pleasing pattern.


def create_tflite_models(config: ConfigData, ds: Dataset) -> None:
  """
  ## Generate a TensorFlow Lite Model

  ### 1. Generate Models with or without Quantization
  We now have an acceptably accurate model. We'll use the
  [TensorFlow Lite Converter](https://www.tensorflow.org/lite/convert)
  to convert the model into a special, space-efficient format for use
  on memory-constrained devices.

  Since this model is going to be deployed on a microcontroller, we want
  it to be as tiny as possible! One technique for reducing the size of a
  model is called [quantization]
  (https://www.tensorflow.org/lite/performance/post_training_quantization).
  It reduces the precision of the model's weights, and possibly the
  activations (output of each layer) as well, which saves memory, often
  without much impact on accuracy. Quantized models also run faster,
  since the calculations required are simpler.

  The saved model at config.MODEL_TF is converted twice: once without
  quantization (written to config.MODEL_NO_QUANT_TFLITE) and once with
  integer-only quantization (written to config.MODEL_QUANT_TFLITE).

  Args:
    config: script configuration data
    ds: dataset whose training partition supplies the representative samples

  Throws: OSError, ValueError
  """
  # Convert the model to the TensorFlow Lite format without quantization
  converter = tf.lite.TFLiteConverter.from_saved_model(config.MODEL_TF)
  model_no_quant_tflite = converter.convert()

  # Save the model to disk
  with open(config.MODEL_NO_QUANT_TFLITE, 'wb') as fp:
    fp.write(model_no_quant_tflite)

  # Convert the model to the TensorFlow Lite format with quantization
  def representative_dataset():
    # Slicing caps the number of samples at the size of the training
    # partition, so a partition smaller than REPRESENTATIVE_SAMPLES yields
    # fewer samples instead of raising IndexError.
    for sample in ds.x_train[:config.REPRESENTATIVE_SAMPLES]:
      yield [sample.reshape(1, 1)]

  # The same converter is reused; only its quantization settings change.
  # Set the optimization flag.
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  # Enforce integer only quantization
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  converter.inference_input_type = tf.int8
  converter.inference_output_type = tf.int8
  # Provide a representative dataset to ensure we quantize correctly.
  converter.representative_dataset = representative_dataset
  model_quant_tflite = converter.convert()

  # Save the model to disk
  with open(config.MODEL_QUANT_TFLITE, 'wb') as fp:
    fp.write(model_quant_tflite)


### 2. Compare Model Performance

# To prove these models are accurate even after conversion and quantization,
# we'll compare their predictions and loss on our test dataset.

# **Helper functions**

# We define the `predict` (for predictions) and `evaluate` (for loss)
# functions for TFLite models.
# Note: These are already included in a TF model, but not in a TFLite model.


def predict_tflite(tflite_model: bytes, x_test: np.ndarray) -> np.ndarray:
  """
  Run a TensorFlow Lite model on every value of x_test

  Args:
    tflite_model: serialized TensorFlow Lite model
    x_test: input values; each value is fed to the model separately

  Returns: 1-D array with one prediction per input value

  Throws: ValueError
  """
  # Prepare the test data: a float32 copy with one sample per row
  samples = x_test.copy().reshape((x_test.size, 1)).astype(np.float32)

  # Initialize the TFLite interpreter
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()

  input_info = interpreter.get_input_details()[0]
  output_info = interpreter.get_output_details()[0]

  # If required, quantize the input layer (from float to integer).
  # Quantization parameters of (0.0, 0) are taken to mean "not quantized".
  # NOTE(review): astype() truncates toward zero rather than rounding to the
  # nearest integer; kept as-is.
  in_scale, in_zero_point = input_info['quantization']
  if (in_scale, in_zero_point) != (0.0, 0):
    samples = (samples / in_scale + in_zero_point).astype(input_info['dtype'])

  # Invoke the interpreter once per sample
  predictions = np.empty(samples.size, dtype=output_info['dtype'])
  for row, sample in enumerate(samples):
    interpreter.set_tensor(input_info['index'], [sample])
    interpreter.invoke()
    predictions[row] = interpreter.get_tensor(output_info['index'])[0]

  # If required, dequantize the output layer (from integer to float)
  out_scale, out_zero_point = output_info['quantization']
  if (out_scale, out_zero_point) != (0.0, 0):
    predictions = (predictions.astype(np.float32) - out_zero_point) * out_scale

  return predictions


def evaluate_tflite(model: tf.keras.Model, tflite_model: bytes,
                    x_test: np.ndarray, y_true: np.ndarray) -> float:
  """
  Compute the loss of a TensorFlow Lite model on the given data

  The loss function is looked up from `model.loss`, so the TFLite model is
  scored with the same loss the Keras model was compiled with.

  Args:
    model: Keras model that supplies the loss function
    tflite_model: serialized TensorFlow Lite model to evaluate
    x_test: input values
    y_true: expected output values

  Returns: the loss value
  """
  predictions = predict_tflite(tflite_model, x_test)
  return tf.keras.losses.get(model.loss)(y_true, predictions).numpy()


def compare_predictions(model: tf.keras.Model, model_no_quant_tflite: bytes,
                        model_quant_tflite: bytes, ds: Dataset,
                        config: ConfigData) -> None:
  """
  **1. Predictions**

  Plot the test-set predictions of the TF model and of both TFLite models
  on top of the actual values and save the figure to
  config.PLOT_COMPARE_PREDICTION.
  """

  # Each entry: (y values, matplotlib format string, legend label).
  # The predictions are calculated here, before the figure is touched.
  series = (
      (ds.y_test, 'bo', 'Actual values'),
      (model.predict(ds.x_test), 'ro', 'TF predictions'),
      (predict_tflite(model_no_quant_tflite, ds.x_test), 'bx',
       'TFLite predictions'),
      (predict_tflite(model_quant_tflite, ds.x_test), 'gx',
       'TFLite quantized predictions'),
  )

  # Compare predictions
  plt.clf()
  plt.title('Comparison of various models against actual values')
  for y_values, marker, legend_label in series:
    plt.plot(ds.x_test, y_values, marker, label=legend_label)
  plt.legend()
  save_plot(config, title=None, filename=config.PLOT_COMPARE_PREDICTION)


def compare_loss(model: tf.keras.Model, model_no_quant_tflite: bytes,
                 model_quant_tflite: bytes, ds: Dataset,
                 config: ConfigData) -> None:
  """
  **2. Loss (MSE/Mean Squared Error)**

  Print a table with the test loss of the TF model and of both TFLite
  models, then check the quantized model's loss against
  config.TEST_QUANT_LOSS.

  Throws: AssertionError
  """

  # Calculate loss
  loss_tf, _ = model.evaluate(ds.x_test, ds.y_test, verbose=0)
  loss_no_quant_tflite, loss_quant_tflite = [
      evaluate_tflite(model,
                      tflite_model=candidate,
                      x_test=ds.x_test,
                      y_true=ds.y_test)
      for candidate in (model_no_quant_tflite, model_quant_tflite)
  ]

  # Compare loss
  rows = [
      ['TensorFlow', loss_tf],
      ['TensorFlow Lite', loss_no_quant_tflite],
      ['TensorFlow Lite Quantized', loss_quant_tflite],
  ]
  table = pd.DataFrame.from_records(rows,
                                    columns=['Model', 'Loss/MSE'],
                                    index='Model')
  print(table.round(4))

  assert loss_quant_tflite <= config.TEST_QUANT_LOSS, \
    'Test loss (quantized) too large'


def _path_size(path: str) -> int:
  """Return the size in bytes of a file, or of all files below a directory"""
  if not os.path.isdir(path):
    return os.path.getsize(path)
  return sum(
      os.path.getsize(os.path.join(dirpath, filename))
      for dirpath, _, filenames in os.walk(path)
      for filename in filenames)


def compare_sizes(config: ConfigData) -> None:
  """
  **3. Size**

  Print a table with the on-disk size of the TF model and of both TFLite
  models, then check that quantization made the TFLite model smaller.

  Args:
    config: script configuration data (supplies the three model paths)

  Throws: AssertionError if the quantized model is not smaller than the
    non-quantized TFLite model; OSError if a model path cannot be read
  """

  # Calculate size
  # config.MODEL_TF is loaded with TFLiteConverter.from_saved_model elsewhere
  # in this script, i.e. it is a SavedModel directory. os.path.getsize() on a
  # directory reports only the size of the directory entry, so the sizes of
  # the files below it are summed instead.
  size_tf = _path_size(config.MODEL_TF)
  size_no_quant_tflite = _path_size(config.MODEL_NO_QUANT_TFLITE)
  size_quant_tflite = _path_size(config.MODEL_QUANT_TFLITE)

  # Compare size
  df = pd.DataFrame.from_records(
      [['TensorFlow', f'{size_tf} bytes', ''],
       [
           'TensorFlow Lite', f'{size_no_quant_tflite} bytes ',
           f'(reduced by {size_tf - size_no_quant_tflite} bytes)'
       ],
       [
           'TensorFlow Lite Quantized', f'{size_quant_tflite} bytes',
           f'(reduced by {size_no_quant_tflite - size_quant_tflite} bytes)'
       ]],
      columns=['Model', 'Size', ''],
      index='Model')
  print(df)

  assert size_quant_tflite < size_no_quant_tflite, \
    'No quantized model size reduction'


# **Summary**

# We can see from the predictions (graph) and loss (table) that the original
# TF model, the TFLite model, and the quantized TFLite model are all close
# enough to be indistinguishable - even though they differ in size (table).
# This implies that the quantized (smallest) model is ready to use!

# *Note: The quantized (integer) TFLite model is just 300 bytes smaller than
# the original (float) TFLite model - a tiny reduction in size! This is
# because the model is already so small that quantization has little effect.
# Complex models with more weights can have up to a 4x reduction in size!*


def main():
  """
  main entry point

  Runs the whole pipeline: seed the RNGs, generate the data, build and train
  the model, convert it to TFLite (with and without quantization) and
  compare predictions, loss and size of the resulting models.
  """
  config = ConfigData()
  set_rng_seed(config)
  dataset = generate_data(config)
  model = create_model()
  train_model(config, model=model, ds=dataset)
  create_tflite_models(config, dataset)

  # Read the converted models back from disk, keyed by the keyword argument
  # under which the compare functions expect them.
  model_files = (
      ('model_quant_tflite', config.MODEL_QUANT_TFLITE),
      ('model_no_quant_tflite', config.MODEL_NO_QUANT_TFLITE),
  )
  tflite_models = {}
  for argument, path in model_files:
    with open(path, 'rb') as fp:
      tflite_models[argument] = fp.read()

  for compare in (compare_predictions, compare_loss):
    compare(model, ds=dataset, config=config, **tflite_models)
  compare_sizes(config=config)


if __name__ == '__main__':
  # Exit status: 0 when the whole pipeline succeeded, 1 when one of the
  # expected failures occurred (its traceback and message are printed first).
  try:
    main()
  except (subprocess.CalledProcessError, AssertionError, OSError, ValueError,
          ImportError) as ex:
    traceback.print_tb(ex.__traceback__)
    print(ex)
    sys.exit(1)
  else:
    sys.exit(0)
